{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### imports\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch, math\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils import data\n",
    "from torch import optim\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "\n",
    "from transformers import AdamW, get_linear_schedule_with_warmup\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "from gensim import corpora, similarities, models\n",
    "from gensim.matutils import corpus2csc\n",
    "from sklearn.model_selection import KFold\n",
    "from lightgbm import LGBMClassifier\n",
    "\n",
    "import pandas as pd \n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import base64\n",
    "from tqdm.auto import tqdm\n",
    "import pickle, random\n",
    "import matplotlib.pyplot as plt\n",
    "from multiprocessing import Pool\n",
    "import sys, csv, json, os, gc, time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### parameters\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:3\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fa08c07e3d0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# GPU index to use when CUDA is available; otherwise everything runs on the CPU.\n",
    "no = '3'\n",
    "device = torch.device('cuda:'+no) if torch.cuda.is_available() else torch.device('cpu')\n",
    "print(device)\n",
    "\n",
    "k = 10\n",
    "lr = 1e-5\n",
    "# true batch size = batch_size * grad_step\n",
    "batch_size = 64\n",
    "margin = 8\n",
    "grad_step = 1\n",
    "# length caps applied to the query tokens and to the per-image feature rows\n",
    "max_lang_len = 15\n",
    "max_img_len = 30\n",
    "epochs = 1\n",
    "MOD = 20000\n",
    "shuffle_fold = True\n",
    "seed = 9527\n",
    "# Seed every RNG the notebook draws from. numpy is included because\n",
    "# KFold(shuffle=True) without a random_state uses numpy's global generator.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class params:\n",
    "    \"\"\"Hyper-parameter namespace; every value is a plain class attribute.\n",
    "\n",
    "    Of these, the model code in this notebook reads IMG_FEAT_SIZE and\n",
    "    WORD_EMBED_SIZE.\n",
    "    \"\"\"\n",
    "    WORD_EMBED_SIZE = 1024\n",
    "    HIDDEN_SIZE = 1024\n",
    "    # 2048 feature dims plus the 6 box columns produced by normalize()\n",
    "    IMG_FEAT_SIZE = 2048+6\n",
    "\n",
    "    LAYER = 6\n",
    "    MULTI_HEAD = 8\n",
    "    DROPOUT_R = 0.1\n",
    "    # values derived from the sizes above\n",
    "    HIDDEN_SIZE_HEAD = HIDDEN_SIZE // MULTI_HEAD\n",
    "    FF_SIZE = 4*HIDDEN_SIZE\n",
    "\n",
    "    FLAT_MLP_SIZE = 512\n",
    "    FLAT_GLIMPSES = 1\n",
    "    FLAT_OUT_SIZE = 2048\n",
    "\n",
    "    OPT_BETAS = (0.9, 0.98)\n",
    "    OPT_EPS = 1e-9\n",
    "    TRAIN_SIZE = 3000000\n",
    "\n",
    "__C = params()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load data\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# leading tokens that are dropped when a query starts with one of them\n",
    "trash = {'!', '$', \"'ll\", \"'s\", ',', '&', ':', 'and', 'cut', 'is', 'are', 'was'}\n",
    "# substrings that are removed wherever they occur in a query\n",
    "trash_replace = ['\"hey siri, play some', 'however, ', 'yin and yang, ',\n",
    "                 'shopping mall/']\n",
    "\n",
    "def process(x):\n",
    "    \"\"\"Clean one query string.\n",
    "\n",
    "    Drops the first whitespace-separated token when it is listed in `trash`\n",
    "    (the remaining tokens are re-joined with single spaces), strips a leading\n",
    "    '-' from the first token, and removes every `trash_replace` substring.\n",
    "\n",
    "    Args:\n",
    "        x: the raw query text (str).\n",
    "\n",
    "    Returns:\n",
    "        The cleaned query. Empty or whitespace-only input no longer raises\n",
    "        IndexError; it only goes through the substring removal.\n",
    "    \"\"\"\n",
    "    tmp = x.split()\n",
    "    if tmp:\n",
    "        if tmp[0] in trash:\n",
    "            x = ' '.join(tmp[1:])\n",
    "        elif tmp[0][0] == '-':\n",
    "            # lstrip first so the character removed is the dash itself,\n",
    "            # not a whitespace character in front of it\n",
    "            x = x.lstrip()[1:]\n",
    "    for tr in trash_replace:\n",
    "        x = x.replace(tr, '')\n",
    "    return x\n",
    "\n",
    "def normalize(x):\n",
    "    \"\"\"Rescale a row's boxes by the image size and append two derived columns.\n",
    "\n",
    "    Columns 0 and 2 are divided by x['image_h'], columns 1 and 3 by\n",
    "    x['image_w']. With d02 = col2-col0 and d13 = col3-col1, the appended\n",
    "    columns are d02*d13 and d02/(d13+1e-6). x['boxes'] itself is not modified.\n",
    "\n",
    "    Args:\n",
    "        x: mapping / DataFrame row with keys 'boxes', 'image_h' and 'image_w'.\n",
    "\n",
    "    Returns:\n",
    "        Array with the 4 rescaled columns followed by the 2 derived ones.\n",
    "    \"\"\"\n",
    "    scaled = x['boxes'].copy()\n",
    "    height, width = x['image_h'], x['image_w']\n",
    "    for col, scale in ((0, height), (1, width), (2, height), (3, width)):\n",
    "        scaled[:, col] /= scale\n",
    "    d02 = scaled[:, 2] - scaled[:, 0]\n",
    "    d13 = scaled[:, 3] - scaled[:, 1]\n",
    "    product = d02 * d13\n",
    "    # the small constant keeps the division finite when d13 is zero\n",
    "    ratio = d02 / (d13 + 1e-6)\n",
    "    return np.hstack((scaled, product.reshape(-1, 1), ratio.reshape(-1, 1)))\n",
    "\n",
    "def sort_by_area(x):\n",
    "    \"\"\"Return the rows of x as a new array, ordered by last column, largest first.\"\"\"\n",
    "    rows = x.tolist()\n",
    "    rows.sort(key=lambda row: row[-1], reverse=True)\n",
    "    return np.array(rows)\n",
    "\n",
    "def load_data(file_name, reset=False, decode=True):\n",
    "    \"\"\"Read a tab-separated file into a DataFrame and prepare its columns.\n",
    "\n",
    "    Args:\n",
    "        file_name: path handed to pd.read_csv (tab-separated).\n",
    "        reset: when True, overwrite query_id with a fresh id per distinct query string.\n",
    "        decode: when True, base64-decode boxes / features / class_labels into numpy\n",
    "            arrays, normalize the boxes and append them to the features.\n",
    "\n",
    "    Returns:\n",
    "        The loaded DataFrame; its query column has been passed through process().\n",
    "    \"\"\"\n",
    "    def unpack(series, dtype, width):\n",
    "        # base64 text -> raw buffer -> array with `width` columns\n",
    "        return series.apply(lambda raw: np.frombuffer(base64.b64decode(raw), dtype=dtype).reshape(-1, width))\n",
    "\n",
    "    def attach_boxes(row):\n",
    "        # put the box columns next to the features, then keep at most max_img_len rows\n",
    "        merged = np.concatenate((row['features'], row['boxes']), axis=1)\n",
    "        return merged[:max_img_len]\n",
    "\n",
    "    frame = pd.read_csv(file_name, sep='\\t')\n",
    "    if decode:\n",
    "        frame['boxes'] = unpack(frame['boxes'], np.float32, 4)\n",
    "        frame['features'] = unpack(frame['features'], np.float32, 2048)\n",
    "        frame['class_labels'] = unpack(frame['class_labels'], np.int64, 1)\n",
    "        frame['boxes'] = frame.apply(normalize, axis=1)\n",
    "        frame['features'] = frame.apply(attach_boxes, axis=1)\n",
    "    frame['query'] = frame['query'].apply(process)\n",
    "    # reset query_id\n",
    "    if reset:\n",
    "        query2qid = {query: qid for qid, (query, _) in enumerate(tqdm(frame.groupby('query')))}\n",
    "        frame['query_id'] = frame.apply(lambda row: query2qid[row['query']], axis=1)\n",
    "    return frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './'\n",
    "# labelled validation frame\n",
    "test = load_data(path+'valid.tsv')\n",
    "# NOTE: this frame is named testA but is read from testB.tsv\n",
    "testA = load_data(path+'testB.tsv')\n",
    "# read the answers through a context manager so the file handle is closed\n",
    "with open(path+'valid_answer.json', 'r') as f:\n",
    "    answers = json.load(f)\n",
    "# target is 1 when the row's product_id is listed among the answers of its query_id\n",
    "test['target'] = test.apply(lambda x: 1 if x['product_id'] in answers[str(x['query_id'])] else 0, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### preprocess\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c4021ad6b9c34bb9a38e405158a15fe1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff8218147cde482a9d84bf7cba4ecdc6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=994.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# load pre-trained model\n",
    "take = 'bert-large-uncased'\n",
    "emb_size = __C.WORD_EMBED_SIZE\n",
    "tokenizer = AutoTokenizer.from_pretrained(take)\n",
    "pretrained_emb = AutoModel.from_pretrained(take)\n",
    "pad_id = tokenizer.pad_token_id\n",
    "sep_id = tokenizer.sep_token_id\n",
    "\n",
    "def clip_tokens(ids):\n",
    "    # sequences longer than max_lang_len keep their first max_lang_len-1 ids\n",
    "    # and get sep_id appended; shorter ones are returned untouched\n",
    "    if len(ids) > max_lang_len:\n",
    "        return ids[:max_lang_len-1]+[sep_id]\n",
    "    return ids\n",
    "\n",
    "# same steps for both frames: encode each query once per query_id, then clip\n",
    "for frame in (test, testA):\n",
    "    qid2token = {qid: tokenizer.encode(group['query'].values[0]) for qid, group in tqdm(frame.groupby('query_id'))}\n",
    "    frame['token'] = frame['query_id'].apply(lambda x: qid2token[x])\n",
    "    frame['token'] = frame['token'].apply(clip_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### model\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FC(nn.Module):\n",
    "    \"\"\"Linear layer optionally followed by ReLU and dropout.\n",
    "\n",
    "    Args:\n",
    "        in_size: input width of the linear layer.\n",
    "        out_size: output width of the linear layer.\n",
    "        dropout_r: dropout probability; 0 disables dropout.\n",
    "        use_relu: apply ReLU after the linear layer when True.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):\n",
    "        super(FC, self).__init__()\n",
    "        self.dropout_r = dropout_r\n",
    "        self.use_relu = use_relu\n",
    "\n",
    "        self.linear = nn.Linear(in_size, out_size)\n",
    "        # both sub-modules are always created; forward() decides whether to use them\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.dropout = nn.Dropout(dropout_r)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.linear(x)\n",
    "        # honour the constructor flag: ReLU used to run even with use_relu=False\n",
    "        if self.use_relu:\n",
    "            x = self.relu(x)\n",
    "        # dropout with probability 0 is the identity, so the call is skipped\n",
    "        if self.dropout_r > 0:\n",
    "            x = self.dropout(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"An FC block (in_size -> mid_size) followed by a linear layer (mid_size -> out_size).\"\"\"\n",
    "    def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):\n",
    "        super().__init__()\n",
    "        self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)\n",
    "        self.linear = nn.Linear(mid_size, out_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = self.fc(x)\n",
    "        return self.linear(hidden)\n",
    "\n",
    "\n",
    "class VisualBERT(nn.Module):\n",
    "    \"\"\"Scores query/image pairs by running both through one BERT encoder.\n",
    "\n",
    "    Query token ids go through BERT's word embeddings, image features are\n",
    "    projected to the same width by ``linear_img``, the two sequences are\n",
    "    concatenated and encoded, and the hidden state at position 0 is mapped\n",
    "    to a single value by ``out``.\n",
    "    \"\"\"\n",
    "    def __init__(self, __C, bert):\n",
    "        # NOTE: inside a class body Python mangles the name __C, so callers pass\n",
    "        # this argument positionally (a keyword would need the mangled spelling).\n",
    "        super(VisualBERT, self).__init__()\n",
    "        \n",
    "        self.bert = bert\n",
    "        self.linear_img = nn.Linear(__C.IMG_FEAT_SIZE, __C.WORD_EMBED_SIZE)\n",
    "        self.out = MLP(__C.WORD_EMBED_SIZE, __C.WORD_EMBED_SIZE//2, 1)\n",
    "        \n",
    "    def forward(self, ques_ix, img_feats):\n",
    "        \"\"\"Score the query batch against every image batch in img_feats.\n",
    "\n",
    "        Args:\n",
    "            ques_ix: integer token ids, indexed as [batch, len]; positions equal\n",
    "                to the notebook-global pad_id are masked out.\n",
    "            img_feats: iterable of image feature tensors; all-zero rows are masked out.\n",
    "\n",
    "        Returns:\n",
    "            A list with one output tensor of ``out`` per entry of img_feats.\n",
    "        \"\"\"\n",
    "        # build helper tensors on the inputs' device instead of relying on the\n",
    "        # notebook-global `device`\n",
    "        dev = ques_ix.device\n",
    "        proj_feats = []\n",
    "        for img_feat in img_feats:\n",
    "            # Make mask & token type ids\n",
    "            mask = self.make_mask(ques_ix.unsqueeze(2), img_feat, pad_id)\n",
    "            token = self.get_token_type(ques_ix, img_feat)\n",
    "            # Preprocess features\n",
    "            lang_feat = self.bert.embeddings.word_embeddings(ques_ix)\n",
    "            img_feat = self.linear_img(img_feat)\n",
    "            combine_feat = torch.cat((lang_feat, img_feat), dim=1)\n",
    "            # Token embeddings & position embeddings\n",
    "            position_ids = torch.arange(token.size(1), dtype=torch.long, device=dev)\n",
    "            position_ids = position_ids.unsqueeze(0).expand(token.size())\n",
    "            position_embeddings = self.bert.embeddings.position_embeddings(position_ids)\n",
    "            token_type_embeddings = self.bert.embeddings.token_type_embeddings(token)\n",
    "            # Add all\n",
    "            embeddings = combine_feat+position_embeddings+token_type_embeddings\n",
    "            embeddings = self.bert.embeddings.dropout(self.bert.embeddings.LayerNorm(embeddings))\n",
    "            # Go through the rest of BERT\n",
    "            head_mask = self.bert.get_head_mask(None, self.bert.config.num_hidden_layers)\n",
    "            extended_attention_mask = self.bert.get_extended_attention_mask(mask, mask.size(), dev)\n",
    "            encoder_outputs = self.bert.encoder(embeddings,\n",
    "                                                attention_mask=extended_attention_mask,\n",
    "                                                head_mask=head_mask,\n",
    "                                                encoder_hidden_states=None,\n",
    "                                                encoder_attention_mask=None)\n",
    "            # CLS embedding & output value\n",
    "            outputs = encoder_outputs[0][:,0,:]\n",
    "            proj_feats.append(self.out(outputs))\n",
    "        return proj_feats\n",
    "            \n",
    "    # Masking\n",
    "    def make_mask(self, lang_feat, img_feat, target):\n",
    "        \"\"\"Return a [batch, len] mask over the language then image positions.\n",
    "\n",
    "        A language position is 0 when the sum of absolute values over its last\n",
    "        axis equals target; an image position is 0 when that sum is 0.\n",
    "        \"\"\"\n",
    "        # 1 for NOT masked; 0 for masked\n",
    "        lang_mask = (torch.sum(torch.abs(lang_feat), dim=-1) != target).long()\n",
    "        img_mask = (torch.sum(torch.abs(img_feat), dim=-1) != 0).long()\n",
    "        return torch.cat((lang_mask, img_mask), dim=1)\n",
    "    \n",
    "    # Token type ids\n",
    "    def get_token_type(self, lang_feat, img_feat):\n",
    "        \"\"\"Return long token-type ids: 0 for language positions, 1 for image positions.\"\"\"\n",
    "        #    lang      img\n",
    "        # 0 0 0 0 0 0 1 1 1 1 1\n",
    "        # created directly as long tensors on the inputs' device\n",
    "        lang_token = torch.zeros(lang_feat.size(0), lang_feat.size(1), dtype=torch.long, device=lang_feat.device)\n",
    "        img_token = torch.ones(img_feat.size(0), img_feat.size(1), dtype=torch.long, device=img_feat.device)\n",
    "        return torch.cat((lang_token, img_token), dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### train\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(data.Dataset):\n",
    "    \"\"\"Dataset over pre-built samples; each sample's first three entries are\n",
    "    returned as [tokens, pos_features, neg_features].\"\"\"\n",
    "    def __init__(self, train_x):\n",
    "        self.train_x = train_x\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        sample = self.train_x[index]\n",
    "        return [sample[0], sample[1], sample[2]]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.train_x)\n",
    "    \n",
    "def collate_fn(batch):\n",
    "    \"\"\"Pad a list of (tokens, pos_features, neg_features) samples into tensors.\n",
    "\n",
    "    Token lists are right-padded with the global pad_id; each feature array is\n",
    "    padded with zero rows up to the longest array of its own group.\n",
    "\n",
    "    Returns:\n",
    "        (LongTensor of tokens, FloatTensor of positives, FloatTensor of negatives)\n",
    "    \"\"\"\n",
    "    def pad_rows(arrays):\n",
    "        # zero-pad every array along axis 0 up to the longest one\n",
    "        longest = max(len(arr) for arr in arrays)\n",
    "        return [np.concatenate((arr, np.zeros((longest-arr.shape[0], arr.shape[1]))), axis=0) for arr in arrays]\n",
    "\n",
    "    tokens, pos_features, neg_features = zip(*batch)\n",
    "    max_len_t = max(len(token) for token in tokens)\n",
    "    padded_tokens = [token+[pad_id]*(max_len_t-len(token)) for token in tokens]\n",
    "    padded_pos = pad_rows(pos_features)\n",
    "    padded_neg = pad_rows(neg_features)\n",
    "    return torch.LongTensor(padded_tokens), torch.FloatTensor(padded_pos), torch.FloatTensor(padded_neg)\n",
    "\n",
    "def custom_schedule(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, amplitude=0.1, last_epoch=-1):\n",
    "    \"\"\"Build a LambdaLR with linear warm-up followed by a sine-modulated linear decay.\n",
    "\n",
    "    During warm-up the factor is current_step / num_warmup_steps. Afterwards it\n",
    "    is abs(linear * (1 + amplitude * sin(progress))), where linear falls from 1\n",
    "    to 0 over the remaining steps and progress covers num_cycles sine periods.\n",
    "    \"\"\"\n",
    "    decay_steps = float(max(1, num_training_steps-num_warmup_steps))\n",
    "\n",
    "    def lr_lambda(current_step):\n",
    "        if current_step >= num_warmup_steps:\n",
    "            linear = float(num_training_steps-current_step) / decay_steps\n",
    "            progress = 2.0 * math.pi * float(num_cycles) * float(current_step-num_warmup_steps) / decay_steps\n",
    "            return abs(linear + math.sin(progress)*linear*amplitude)\n",
    "        return float(current_step) / float(max(1, num_warmup_steps))\n",
    "\n",
    "    return LambdaLR(optimizer, lr_lambda, last_epoch)\n",
    "\n",
    "def shuffle(x):\n",
    "    \"\"\"Return x indexed by a random permutation of its first axis (uses the random module).\"\"\"\n",
    "    order = list(range(x.shape[0]))\n",
    "    random.shuffle(order)\n",
    "    return x[order]\n",
    "\n",
    "def nDCG_score(preds, answers, k=5):\n",
    "    \"\"\"Compute nDCG@k, summed over all queries before normalising.\n",
    "\n",
    "    Args:\n",
    "        preds: dict mapping query id -> ranked list of predicted product ids.\n",
    "        answers: dict mapping str(query id) -> collection of relevant product ids.\n",
    "        k: rank cut-off, applied to both the predictions and the ideal ranking.\n",
    "\n",
    "    Returns:\n",
    "        DCG / iDCG, or 0.0 when iDCG is zero (no query has any answer).\n",
    "    \"\"\"\n",
    "    # discount of rank i (0-based) is log(2)/log(i+2)\n",
    "    discounts = [np.log(2)/np.log(rank+2) for rank in range(k)]\n",
    "    iDCG = sum(sum(discounts[:len(answer)]) for answer in answers.values())\n",
    "    # zip() stops after k predictions, so DCG is cut at the same rank as iDCG\n",
    "    # and the score can no longer exceed 1 for over-long prediction lists\n",
    "    DCG = sum(sum(gain for pid, gain in zip(preds[qid], discounts) if pid in answers[str(qid)])\n",
    "              for qid in preds)\n",
    "    if iDCG == 0:\n",
    "        return 0.0\n",
    "    return DCG/iDCG\n",
    "\n",
    "class FocalLoss(nn.Module):\n",
    "    \"\"\"Binary focal loss: alpha * (1 - pt)**gamma * BCE, with pt = exp(-BCE).\n",
    "\n",
    "    Args:\n",
    "        alpha: constant multiplier of the loss.\n",
    "        gamma: exponent of the (1 - pt) factor.\n",
    "        logits: when True the inputs are raw logits, otherwise probabilities.\n",
    "        reduce: when True return the mean over all elements, otherwise the\n",
    "            element-wise loss.\n",
    "    \"\"\"\n",
    "    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):\n",
    "        super(FocalLoss, self).__init__()\n",
    "        self.alpha = alpha\n",
    "        self.gamma = gamma\n",
    "        self.logits = logits\n",
    "        self.reduce = reduce\n",
    "\n",
    "    def forward(self, inputs, targets):\n",
    "        bce = F.binary_cross_entropy_with_logits if self.logits else F.binary_cross_entropy\n",
    "        # reduction='none' replaces the deprecated reduce=False argument\n",
    "        BCE_loss = bce(inputs, targets, reduction='none')\n",
    "        pt = torch.exp(-BCE_loss)\n",
    "        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss\n",
    "\n",
    "        return torch.mean(F_loss) if self.reduce else F_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initializing model...\n"
     ]
    }
   ],
   "source": [
    "print('initializing model...')\n",
    "# score book-keeping; neither name is written to again in the cells below\n",
    "nDCGs = []\n",
    "best_nDCG = 0.0\n",
    "# wrap the pre-trained BERT loaded above and move the whole model to `device`\n",
    "model = VisualBERT(__C, pretrained_emb).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### prediction\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_embedding(model, ques_ix, img_feat):\n",
    "    \"\"\"Return the encoder's position-0 (CLS) hidden state for each query/image pair.\n",
    "\n",
    "    Runs the same steps as one iteration of VisualBERT.forward but stops before\n",
    "    the output MLP, so the raw embedding is returned instead of a score.\n",
    "\n",
    "    Args:\n",
    "        model: a VisualBERT instance.\n",
    "        ques_ix: integer token ids, indexed as [batch, len]; positions equal to\n",
    "            the notebook-global pad_id are masked out.\n",
    "        img_feat: image feature tensor for the same batch; all-zero rows are masked out.\n",
    "\n",
    "    Returns:\n",
    "        encoder_outputs[0][:, 0, :], one embedding per batch entry.\n",
    "    \"\"\"\n",
    "    # build helper tensors on the inputs' device instead of relying on the\n",
    "    # notebook-global `device`\n",
    "    dev = ques_ix.device\n",
    "    # Make mask & token type ids\n",
    "    mask = model.make_mask(ques_ix.unsqueeze(2), img_feat, pad_id)\n",
    "    token = model.get_token_type(ques_ix, img_feat)\n",
    "    # Preprocess features\n",
    "    lang_feat = model.bert.embeddings.word_embeddings(ques_ix)\n",
    "    img_feat = model.linear_img(img_feat)\n",
    "    combine_feat = torch.cat((lang_feat, img_feat), dim=1)\n",
    "    # Token embeddings & position embeddings\n",
    "    position_ids = torch.arange(token.size(1), dtype=torch.long, device=dev)\n",
    "    position_ids = position_ids.unsqueeze(0).expand(token.size())\n",
    "    position_embeddings = model.bert.embeddings.position_embeddings(position_ids)\n",
    "    token_type_embeddings = model.bert.embeddings.token_type_embeddings(token)\n",
    "    # Add all\n",
    "    embeddings = combine_feat+position_embeddings+token_type_embeddings\n",
    "    embeddings = model.bert.embeddings.dropout(model.bert.embeddings.LayerNorm(embeddings))\n",
    "    # Go through the rest of BERT\n",
    "    head_mask = model.bert.get_head_mask(None, model.bert.config.num_hidden_layers)\n",
    "    extended_attention_mask = model.bert.get_extended_attention_mask(mask, mask.size(), dev)\n",
    "    encoder_outputs = model.bert.encoder(embeddings,\n",
    "                                         attention_mask=extended_attention_mask,\n",
    "                                         head_mask=head_mask,\n",
    "                                         encoder_hidden_states=None,\n",
    "                                         encoder_attention_mask=None)\n",
    "    # CLS embedding & output value\n",
    "    outputs = encoder_outputs[0][:,0,:]\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cls(model, n_splits, random_state=None):\n",
    "    \"\"\"Fit a LightGBM classifier on the model's embeddings of the global `test` frame.\n",
    "\n",
    "    Query ids are split into n_splits shuffled folds. Rows of queries in fold 0\n",
    "    form the evaluation set used for early stopping; rows of all other folds\n",
    "    form the training set.\n",
    "\n",
    "    Args:\n",
    "        model: a VisualBERT instance; it is switched to eval mode.\n",
    "        n_splits: number of folds for KFold.\n",
    "        random_state: optional seed for the fold shuffle. The default (None)\n",
    "            keeps the previous behaviour of drawing from numpy's global RNG.\n",
    "\n",
    "    Returns:\n",
    "        The fitted LGBMClassifier.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    qids = [[qid] for qid, _ in test.groupby('query_id')]\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
    "    train_x, train_y = [], []\n",
    "    test_x, test_y = [], []\n",
    "    # map every query id to the index of the fold whose held-out part contains it\n",
    "    qid2fold = {qids[idx][0]: i\n",
    "                for i, (_, test_index) in enumerate(kf.split(qids))\n",
    "                for idx in test_index}\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for qid, group in tqdm(test.groupby('query_id')):\n",
    "            # prepare batch\n",
    "            tokens, features = group['token'].values.tolist(), group['features'].values.tolist()\n",
    "            max_len_f = max(len(feature) for feature in features)\n",
    "            features = [np.concatenate((feature, np.zeros((max_len_f-feature.shape[0], feature.shape[1]))), axis=0) for feature in features]\n",
    "            # to tensor\n",
    "            tokens = torch.LongTensor(tokens).to(device)\n",
    "            features = torch.FloatTensor(features).to(device)\n",
    "            # predict\n",
    "            tmp_x = extract_embedding(model, tokens, features).tolist()\n",
    "            tmp_y = group['target'].values.tolist()\n",
    "            # only use first fold: fold 0 is held out, every other fold is training data\n",
    "            if qid2fold[qid] != 0:\n",
    "                train_x += tmp_x\n",
    "                train_y += tmp_y\n",
    "            else:\n",
    "                test_x += tmp_x\n",
    "                test_y += tmp_y\n",
    "    \n",
    "    train_x, train_y = np.array(train_x), np.array(train_y)\n",
    "    test_x, test_y = np.array(test_x), np.array(test_y)\n",
    "    print('train:test = {}:{}'.format(train_x.shape[0], test_x.shape[0]))\n",
    "    cls = LGBMClassifier(random_state=0, n_jobs=24)\n",
    "    cls.fit(train_x, train_y,\n",
    "            eval_set=[(test_x, test_y)],\n",
    "            early_stopping_rounds=100,\n",
    "            verbose=100)\n",
    "        \n",
    "    return cls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model, test, thd, cls):\n",
    "    \"\"\"Rank the candidates of every query and keep the top 5 product ids.\n",
    "\n",
    "    Candidates whose product_id occurs more than thd times in `test` are\n",
    "    filtered out; the threshold is raised one step at a time until enough\n",
    "    candidates survive.\n",
    "\n",
    "    Args:\n",
    "        model: a VisualBERT instance; it is switched to eval mode.\n",
    "        test: DataFrame with query_id, product_id, token and features columns.\n",
    "        thd: initial upper bound on how often a product_id may occur in `test`.\n",
    "        cls: fitted classifier exposing predict_proba.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping query_id -> list of at most 5 product ids, best first.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    counts = Counter(test['product_id'].values.tolist())\n",
    "    preds = {}\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for qid, group in tqdm(test.groupby('query_id')):\n",
    "            # prepare batch\n",
    "            tokens, features = group['token'].values.tolist(), group['features'].values.tolist()\n",
    "            max_len_f = max(len(feature) for feature in features)\n",
    "            features = [np.concatenate((feature, np.zeros((max_len_f-feature.shape[0], feature.shape[1]))), axis=0) for feature in features]\n",
    "            # to tensor\n",
    "            tokens = torch.LongTensor(tokens).to(device)\n",
    "            features = torch.FloatTensor(features).to(device)\n",
    "            # predict\n",
    "            embeddings = np.array(extract_embedding(model, tokens, features).tolist())\n",
    "            out = cls.predict_proba(embeddings)[:,1]\n",
    "            pred = list(zip(group['product_id'].values.tolist(), out.tolist()))\n",
    "            # a query with fewer than 5 candidates can never yield 5 survivors;\n",
    "            # without this cap the relaxation loop below would never terminate\n",
    "            needed = min(5, len(pred))\n",
    "            thd_tmp = thd\n",
    "            pred2 = [x for x in pred if counts[x[0]] <= thd_tmp]\n",
    "            while len(pred2) < needed:\n",
    "                thd_tmp += 1\n",
    "                pred2 = [x for x in pred if counts[x[0]] <= thd_tmp]\n",
    "            pred2.sort(key=lambda x: x[1], reverse=True)\n",
    "            preds[qid] = [pid for pid, _ in pred2[:5]]\n",
    "            \n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_all(model, test, pad_len, cls):\n",
    "    \"\"\"Score every candidate of every query and return padded ranking rows.\n",
    "\n",
    "    Args:\n",
    "        model: a VisualBERT instance; it is switched to eval mode.\n",
    "        test: DataFrame with query_id, product_id, token and features columns.\n",
    "        pad_len: row width per half; a query with more candidates fails the assert.\n",
    "        cls: fitted classifier exposing predict_proba.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping query_id -> pad_len product ids followed by pad_len scores,\n",
    "        both sorted by descending score and padded with NaN.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    preds = {}\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for qid, group in tqdm(test.groupby('query_id')):\n",
    "            # prepare batch\n",
    "            token_rows = group['token'].values.tolist()\n",
    "            feature_rows = group['features'].values.tolist()\n",
    "            longest = max(len(feature) for feature in feature_rows)\n",
    "            feature_rows = [np.concatenate((feature, np.zeros((longest-feature.shape[0], feature.shape[1]))), axis=0) for feature in feature_rows]\n",
    "            # to tensor\n",
    "            token_batch = torch.LongTensor(token_rows).to(device)\n",
    "            feature_batch = torch.FloatTensor(feature_rows).to(device)\n",
    "            # predict\n",
    "            embeddings = np.array(extract_embedding(model, token_batch, feature_batch).tolist())\n",
    "            out = cls.predict_proba(embeddings)[:,1]\n",
    "            ranked = sorted(zip(group['product_id'].values.tolist(), out.tolist()), key=lambda x: x[1], reverse=True)\n",
    "            assert len(ranked) <= pad_len\n",
    "            # NaN filler brings both halves up to pad_len entries\n",
    "            filler = [np.nan]*(pad_len-len(ranked))\n",
    "            preds[qid] = [p for p, _ in ranked] + filler + [s for _, s in ranked] + filler\n",
    "            \n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed: 7328; fold: 30\n",
      "seed: 7328; fold: 31\n",
      "seed: 7328; fold: 32\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b0afbc35ddc74be0acc1a5a5506fbfee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train:test = 13250:1470\n",
      "Training until validation scores don't improve for 100 rounds\n",
      "[100]\tvalid_0's binary_logloss: 0.307157\n",
      "Did not meet early stopping. Best iteration is:\n",
      "[100]\tvalid_0's binary_logloss: 0.307157\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd48fbb1de4c4d0782a0e2237bf94410",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "nDCG@5: 0.9743776440774257\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60ab4a988f6d41a6b40caa4a0f3adbf7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=994.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "seed: 7328; fold: 33\n",
      "seed: 7328; fold: 34\n",
      "seed: 7328; fold: 35\n",
      "seed: 7328; fold: 36\n",
      "seed: 7328; fold: 37\n",
      "seed: 7328; fold: 38\n",
      "seed: 7328; fold: 39\n",
      "time consumed: 3 min 40 sec\n",
      "seed: 62; fold: 30\n",
      "seed: 62; fold: 31\n",
      "seed: 62; fold: 32\n",
      "seed: 62; fold: 33\n",
      "seed: 62; fold: 34\n",
      "seed: 62; fold: 35\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32e60598856446f18327bb171eebeb55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train:test = 13233:1487\n",
      "Training until validation scores don't improve for 100 rounds\n",
      "[100]\tvalid_0's binary_logloss: 0.301998\n",
      "Did not meet early stopping. Best iteration is:\n",
      "[97]\tvalid_0's binary_logloss: 0.301317\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa4d486991f94bd7ba47b83c23d13cb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "nDCG@5: 0.9747210942496021\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d3d9d142fb04268a6a8e109e8ca184a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=994.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "seed: 62; fold: 36\n",
      "seed: 62; fold: 37\n",
      "seed: 62; fold: 38\n",
      "seed: 62; fold: 39\n",
      "time consumed: 3 min 40 sec\n",
      "seed: 3951; fold: 30\n",
      "seed: 3951; fold: 31\n",
      "seed: 3951; fold: 32\n",
      "seed: 3951; fold: 33\n",
      "seed: 3951; fold: 34\n",
      "seed: 3951; fold: 35\n",
      "seed: 3951; fold: 36\n",
      "seed: 3951; fold: 37\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00c10c402f684774bff59d4f71edd1d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train:test = 13229:1491\n",
      "Training until validation scores don't improve for 100 rounds\n",
      "[100]\tvalid_0's binary_logloss: 0.335173\n",
      "Did not meet early stopping. Best iteration is:\n",
      "[98]\tvalid_0's binary_logloss: 0.334506\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a21b364106ac4e3998b7ea155df1e49a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "nDCG@5: 0.9694379885062698\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "88044c2863dc4d26bb15d44214b35592",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=994.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "seed: 3951; fold: 38\n",
      "seed: 3951; fold: 39\n",
      "time consumed: 3 min 41 sec\n",
      "seed: 9736; fold: 30\n",
      "seed: 9736; fold: 31\n",
      "seed: 9736; fold: 32\n",
      "seed: 9736; fold: 33\n",
      "seed: 9736; fold: 34\n",
      "seed: 9736; fold: 35\n",
      "seed: 9736; fold: 36\n",
      "seed: 9736; fold: 37\n",
      "seed: 9736; fold: 38\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e03ccfd2c56247d89a0c68912870bcb4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train:test = 13229:1491\n",
      "Training until validation scores don't improve for 100 rounds\n",
      "[100]\tvalid_0's binary_logloss: 0.316825\n",
      "Did not meet early stopping. Best iteration is:\n",
      "[100]\tvalid_0's binary_logloss: 0.316825\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86a14ace1e3449a1a0c98d40d28c64f3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=496.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "nDCG@5: 0.9729549296935117\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a44e72aace2142bd8b8a6cf326161bc4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=994.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "seed: 9736; fold: 39\n",
      "time consumed: 3 min 42 sec\n"
     ]
    }
   ],
   "source": [
    "# seeds with a value of 0 are skipped by the loop below\n",
    "seeds = {7328: 32, 62: 35, 3951: 37, 9736: 38,\n",
    "         65: 0, 332: 0, 206: 0,\n",
    "         9876: 0, 888: 0, 333: 0}\n",
    "folds = [i for i in range(30, 40)]\n",
    "pad_len = 30\n",
    "thd = 9999\n",
    "n_splits = 10\n",
    "\n",
    "# make sure the output directory exists before the first file is written\n",
    "os.makedirs('predictions', exist_ok=True)\n",
    "\n",
    "for seed in list(seeds.keys()):\n",
    "    if not seeds[seed]: continue\n",
    "    t0 = time.time()\n",
    "    for fold in folds:\n",
    "        print('seed: {}; fold: {}'.format(seed, fold))\n",
    "        # load model weights\n",
    "        model_path = path+'models/model_Visual-BERT_pair_box_tfidf-neg_focal_all_{}_{}'.format(seed, fold)\n",
    "        # checkpoints that are not on disk are skipped silently, as before\n",
    "        if not os.path.isfile(model_path): continue\n",
    "        try:\n",
    "            model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "        except Exception as err:\n",
    "            # still best-effort, but no longer a bare except: the reason is\n",
    "            # reported and KeyboardInterrupt is not swallowed\n",
    "            print('skipping {}: {}'.format(model_path, err))\n",
    "            continue\n",
    "        # train cls\n",
    "        cls = get_cls(model, n_splits)\n",
    "        # valid\n",
    "        preds = predict(model, test, thd, cls)\n",
    "        nDCG = nDCG_score(preds, answers)\n",
    "        print('nDCG@5:', nDCG)\n",
    "        # test\n",
    "        preds = predict_all(model, testA, pad_len, cls)\n",
    "        # write to file\n",
    "        header = ['qid'] + ['p'+str(i) for i in range(pad_len)] + ['s'+str(i) for i in range(pad_len)]\n",
    "        with open('predictions/prediction_all_cls2_{}_{}.csv'.format(seed, fold), 'w', newline='') as f:\n",
    "            w = csv.writer(f)\n",
    "            w.writerow(header)\n",
    "            for qid in sorted(list(preds.keys())):\n",
    "                w.writerow([qid]+preds[qid])\n",
    "    t = round(time.time()-t0)\n",
    "    print('time consumed: {} min {} sec'.format(t//60, t%60))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
